Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python bindings and tests for Triu #3637

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open

Python bindings and tests for Triu #3637

wants to merge 10 commits into from

Conversation

protonu
Copy link
Collaborator

@protonu protonu commented Dec 23, 2024

This PR exposes the Triu C++ API which was added in the PR (#3631)

@protonu protonu force-pushed the pbasu_triu_py_layer branch 3 times, most recently from 6ae6033 to 6d2e7ab Compare December 23, 2024 21:16
@protonu protonu force-pushed the pbasu_iota_experiment branch from 260141f to 5f82c65 Compare December 23, 2024 21:33
@protonu protonu force-pushed the pbasu_triu_py_layer branch from 6d2e7ab to a3ab248 Compare December 23, 2024 21:43
@protonu protonu force-pushed the pbasu_iota_experiment branch from 0b42b74 to b734bf6 Compare January 2, 2025 20:12
@protonu protonu force-pushed the pbasu_triu_py_layer branch 3 times, most recently from e134264 to d66b838 Compare January 4, 2025 15:47
@protonu protonu marked this pull request as ready for review January 4, 2025 15:48
@protonu
Copy link
Collaborator Author

protonu commented Jan 4, 2025

!test

@protonu protonu changed the title [WIP] Python bindings and tests for Triu Python bindings and tests for Triu Jan 4, 2025
csrc/ops/composite.h Outdated Show resolved Hide resolved
csrc/python_frontend/python_bindings.cpp Outdated Show resolved Hide resolved
tests/python/opinfo_input_generators.py Show resolved Hide resolved
csrc/python_frontend/python_bindings.cpp Outdated Show resolved Hide resolved
protonu added a commit that referenced this pull request Jan 6, 2025
This is a C++ API to implement `triu`. The PR
(#3637) stacked on top of this used
it to create a Python interface.

Another way of using this may be to use the components of `triu` such as
`iota`, `broadcast`, `le` and `where` from Thunder directly bypassing
the need for a C++ implementation. As future work this commit/PR may be
removed in favor of a Thunder only implementation.
Base automatically changed from pbasu_iota_experiment to main January 6, 2025 19:51
@protonu protonu force-pushed the pbasu_triu_py_layer branch from 1794bac to 0426b46 Compare January 7, 2025 00:06
@protonu
Copy link
Collaborator Author

protonu commented Jan 7, 2025

!test

@protonu protonu force-pushed the pbasu_triu_py_layer branch from 0426b46 to 8339c17 Compare January 7, 2025 00:08
@protonu
Copy link
Collaborator Author

protonu commented Jan 7, 2025

@jacobhinkle can you take another look?

@@ -28,10 +28,13 @@ def parse_inputs_fusion_definition(fd: FusionDefinition, opinfo: OpInfo, *args):
)

num_symbolic_parameters = len(symbolic_parameter_list)
assert num_symbolic_parameters == len(
assert num_symbolic_parameters >= len(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdspring1 does this change look reasonable?

@@ -63,7 +63,9 @@ def parse_args_fusion_execution(opinfo: OpInfo, *args):
else [ArgumentType.Symbolic] * len(args)
)

assert len(symbolic_parameter_list) == len(args)
assert len(symbolic_parameter_list) >= len(args)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rdspring1 does this change look okay? I had to make this change we could be generating args of different lengths.

Comment on lines +35 to +37
if num_symbolic_parameters > len(args):
symbolic_parameter_list = symbolic_parameter_list[: len(args)]

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt this is necessary as zip will truncate to the length of the shortest argument.

},
py::arg("input"),
py::arg("diagonal") = 0,
py::return_value_policy::reference);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bonus points for adding a pytorch or numpy style docstring. It should go as the last argument in the def function.

Reference: https://pybind11.readthedocs.io/en/stable/advanced/misc.html#generating-documentation-using-sphinx

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants